# -*- coding: utf-8 -*-
"""
Created on Tue Jul 21 10:08:50 2020

@author: 81283
"""
from environment.env import next_state_reward_end
from .value import q_dict
from .action import epsilon_action, greedy_action


# Discount factor (gamma) for future rewards — hyperparameter
discount=0.95      # 0.95

# Learning rate (alpha) for the TD update — hyperparameter
learning_rate=0.1


def q_learning(state, no_cliff=False):
    """Run one Q-learning (off-policy TD) update step starting from `state`.

    The behavior policy is epsilon-greedy; the target policy is greedy,
    which makes this an off-policy update.

    Args:
        state: current state (e.g. a grid coordinate tuple).
        no_cliff: forwarded to the environment transition function.

    Returns:
        (next_state, reward, end_flag) — the transition taken; end_flag
        is True when next_state is terminal.
    """
    # Behavior policy: pick the action to actually take (epsilon-greedy).
    act, act_q = epsilon_action(state)

    # Step the environment; end_flag is True when next_state is terminal.
    next_state, reward, end_flag = next_state_reward_end(state, act, no_cliff)

    # Target policy: evaluate the next state greedily (off-policy bootstrap).
    _, best_next_q = greedy_action(next_state)

    # TD target and TD error for the taken (state, action) pair.
    td_target = reward + discount * best_next_q
    td_error = td_target - act_q

    # Nudge the stored action value toward the TD target.
    q_dict[(state, act)] += learning_rate * td_error

    return next_state, reward, end_flag


def sarsa_learning(state, no_cliff=False):
    """Run one SARSA (on-policy TD) update step starting from `state`.

    Both the behavior policy and the bootstrap policy are epsilon-greedy,
    which makes this an on-policy update.

    Args:
        state: current state (e.g. a grid coordinate tuple).
        no_cliff: forwarded to the environment transition function.

    Returns:
        (next_state, reward, end_flag) — the transition taken; end_flag
        is True when next_state is terminal.
    """
    # Behavior policy: pick the action to actually take (epsilon-greedy).
    act, act_q = epsilon_action(state)

    # Step the environment; end_flag is True when next_state is terminal.
    next_state, reward, end_flag = next_state_reward_end(state, act, no_cliff)

    # On-policy bootstrap: evaluate the next state with the same
    # epsilon-greedy policy used for acting.
    _, next_q = epsilon_action(next_state)

    # TD target and TD error for the taken (state, action) pair.
    td_target = reward + discount * next_q
    td_error = td_target - act_q

    # Nudge the stored action value toward the TD target.
    q_dict[(state, act)] += learning_rate * td_error

    return next_state, reward, end_flag





# 最终确定性策略下的奖励
def do_final_reward(no_cliff=False):
    reward_sum=0
    state=(0, 0)
    visit_hist=[state] #访问的状态历史记录
    solved=True # 返回是否找到终点的标志
    
    for i in range(10000):#超过10000steps认为不能走到终点，进入死循环
        action, _ = greedy_action(state)  # 贪婪策略选择动作 
    
        # end_flag 为true则next_state为终止状态
        next_state, reward, end_flag=next_state_reward_end(state, action, no_cliff)   
        
        reward_sum+=reward
        visit_hist.append(next_state) # 将访问的节点加入历史记录
        
        if end_flag==True:
            break
        else:
            state=next_state

    if (i+1)==10000:
        solved=False
        return reward_sum, solved, visit_hist # 没有找到终点 
    else:
        return reward_sum, solved, visit_hist # 可以找到终点